# adapted from https://github.com/salesforce/PCL
from __future__ import print_function
import os
import torch
import argparse
import random
import numpy as np

from torchvision import transforms, datasets
import torchvision.models as models

from dataset_helpers import Voc2007Classification, ImageFolderLowshot

from sklearn.svm import LinearSVC
import sys
sys.path.append('../pretraining/dino')
from utils import load_pretrained_weights
import vision_transformer as vits

# When True, print progress and diagnostic messages during evaluation.
VERBOSE=True

def parse_option():
    """
    Parse command-line options for the linear-SVM evaluation.

    Returns:
        argparse.Namespace holding the CLI arguments plus two derived fields:
        ``num_class`` (number of classes of ``--dataset``) and ``n_run``
        (5 random k-shot resamplings when ``--low-shot`` is set, else 1).

    Exits via ``parser.error`` (status 2) on an unknown ``--dataset`` or a
    ``--cost`` entry that is not a number.
    """
    model_names = sorted(name for name in models.__dict__
                         if name.islower() and not name.startswith("__")
                         and callable(models.__dict__[name]))
    # Class count per supported dataset; also the set of valid --dataset values.
    num_classes_by_dataset = {'voc2007': 20, 'places205': 205, 'flowers102': 102,
                              'cars196': 196, 'herba19': 683}

    parser = argparse.ArgumentParser('argument for training')

    parser.add_argument('data', metavar='DIR',
                        help='path to dataset')
    parser.add_argument('--batch-size', type=int, default=64, help='batch size')
    parser.add_argument('--num-workers', type=int, default=8, help='num of workers to use')
    parser.add_argument('--cost', type=str, default='0.5',
                        help='comma-separated SVM C values; the best result over them is reported')
    parser.add_argument('--seed', default=0, type=int)

    # model definition
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
                        help='model architecture: ' +
                             ' | '.join(model_names) +
                             ', or a vision_transformer arch whose name contains "deit"' +
                             ' (default: resnet50)')
    parser.add_argument('--pretrained', default='', type=str,
                        help='path to pretrained checkpoint')
    parser.add_argument('--dataset', default='voc2007', type=str,
                        choices=sorted(num_classes_by_dataset),
                        help='dataset name')
    # dataset
    parser.add_argument('--low-shot', default=False, action='store_true', help='whether to perform low-shot training.')

    opt = parser.parse_args()

    # Validate the cost list now rather than after the (slow) feature extraction.
    try:
        [float(c) for c in opt.cost.split(',')]
    except ValueError:
        parser.error('--cost must be a comma-separated list of numbers, got {!r}'.format(opt.cost))

    opt.num_class = num_classes_by_dataset[opt.dataset]

    # if low shot experiment, do 5 random runs
    opt.n_run = 5 if opt.low_shot else 1
    return opt


def calculate_ap(rec, prec):
    """
    Return the area under a precision/recall curve.

    The curve is padded with a (recall=0, precision=0) start point and a
    (recall=1, precision=0) end point. Precision is then replaced by its
    running maximum taken from the high-recall end. Finally, rectangles are
    summed at every position where recall changes.

    Args:
        rec: numpy array of recall values, in curve order.
        prec: numpy array of precision values, same length as ``rec``.

    Returns:
        A float64 numpy array of shape (1,) holding the area.
    """
    recall_pts = np.concatenate(([0.0], rec.ravel(), [1.0]))
    precision_pts = np.concatenate(([0.0], prec.ravel(), [0.0]))

    # Precision envelope: each point takes the best precision seen at any
    # later (higher-recall) point.
    for pos in reversed(range(precision_pts.size - 1)):
        precision_pts[pos] = max(precision_pts[pos], precision_pts[pos + 1])

    # Only positions where recall actually moves contribute area.
    steps = np.flatnonzero(recall_pts[1:] != recall_pts[:-1]) + 1
    area = 0
    for pos in steps:
        area = area + (recall_pts[pos] - recall_pts[pos - 1]) * precision_pts[pos]
    return np.array([area])

def get_precision_recall(targets, preds):
    """
    Compute the precision/recall curve and average precision for one class.

    [P, R, score, ap] = get_precision_recall(targets, preds)
    Input    :
        targets  : 1-D per-image labels for this class; entries > 0 count as
                   positives, anything else (e.g. 0 or -1) as negatives
        preds    : 1-D per-image scores, same length as ``targets``;
                   higher means more confidently positive
    Output   :
        P, R   : precision and recall after each image, visiting the images
                 in descending-score order
        score  : the scores in that same descending order (float64)
        ap     : average precision, as returned by ``calculate_ap(R, P)``
    Raises   :
        ValueError if ``targets`` and ``preds`` are not 1-D arrays of equal
        length
    Note     :
        If ``targets`` is non-empty but has no positives, recall is 0/0 and
        ``ap`` comes out NaN.
    """
    # binarize targets
    positives = np.asarray(np.asarray(targets) > 0, dtype=np.float64)
    preds = np.asarray(preds)
    if positives.ndim != 1 or positives.shape != preds.shape:
        raise ValueError(
            'targets and preds must be 1-D arrays of equal length, got shapes '
            '{} and {}'.format(positives.shape, preds.shape))

    # Rank images from highest to lowest score.
    order = np.argsort(preds)[::-1]
    score = preds.astype(np.float64)[order]
    tp = positives[order]
    # Every ranked image is either a true positive or a false positive.
    fp = 1.0 - tp

    tp_cum = np.cumsum(tp)
    P = tp_cum / (tp_cum + np.cumsum(fp))
    numinst = np.sum(positives)
    R = tp_cum / numinst
    ap = calculate_ap(R, P)
    return P, R, score, ap

def _build_model(args):
    """
    Create the backbone named by ``args.arch`` and load ``args.pretrained`` into it.

    Arch names containing 'deit' are built from ``vision_transformer`` and
    loaded with ``load_pretrained_weights`` (teacher key, patch size 16).
    Every other name is a torchvision constructor. For those, a checkpoint
    found at ``args.pretrained`` is loaded non-strictly after stripping
    wrapper prefixes, and ``fc`` is replaced by an identity. A missing or
    empty path leaves the freshly constructed model untouched.
    """
    if VERBOSE:
        print("=> creating model '{}'".format(args.arch))

    if 'deit' in args.arch:
        model = vits.__dict__[args.arch](patch_size=16, num_classes=0)
        print(f"Model {args.arch} built.")
        # load weights to evaluate
        load_pretrained_weights(model, args.pretrained, 'teacher', args.arch, 16)
        return model

    model = models.__dict__[args.arch](num_classes=128)
    if not args.pretrained:
        return model
    if not os.path.isfile(args.pretrained):
        if VERBOSE:
            print("=> no checkpoint found at '{}'".format(args.pretrained))
        return model

    if VERBOSE:
        print("=> loading checkpoint '{}'".format(args.pretrained))
    checkpoint = torch.load(args.pretrained, map_location="cpu")
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    elif 'teacher' in checkpoint:
        state_dict = checkpoint['teacher']
    else:
        state_dict = checkpoint
    # Strip 'module.' and 'backbone.' prefixes so keys match the bare model.
    state_dict = {k.replace('module.', '').replace('backbone.', ''): v
                  for k, v in state_dict.items()}
    # Checkpoints with a query encoder (keys under 'encoder_q.'): keep only it.
    if any('encoder_q' in k for k in state_dict):
        state_dict = {k.replace('encoder_q.', ''): v
                      for k, v in state_dict.items() if 'encoder_q' in k}
    # Drop the head weights; the head is replaced by an identity below.
    state_dict = {k: v for k, v in state_dict.items() if not k.startswith('fc')}
    msg = model.load_state_dict(state_dict, strict=False)
    model.fc = torch.nn.Identity()
    if VERBOSE:
        # strict=False silently skips mismatches, so surface them here.
        print("=> missing keys: {}".format(msg.missing_keys))
        print("=> loaded pre-trained model '{}'".format(args.pretrained))
    return model


@torch.no_grad()
def _extract_features(model, loader):
    """Run ``model`` over ``loader``; return (L2-normalised features, labels) as numpy arrays."""
    feats = []
    labels = []
    for images, target in loader:
        images = images.cuda(non_blocking=True)
        feats.append(model(images).detach().cpu())
        labels.append(target.cpu())
    feats = torch.cat(feats, 0).numpy()
    labels = torch.cat(labels, 0).numpy()
    # The epsilon avoids division by zero for all-zero feature vectors.
    norms = np.linalg.norm(feats, axis=1)
    feats = feats / (norms + 1e-5)[:, np.newaxis]
    return feats, labels


def _compute_train_features(model, train_dataset, args):
    """Extract training features/labels, with 0 labels remapped to -1 for the SVMs."""
    if VERBOSE:
        print(len(train_dataset))
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=True)
    if VERBOSE:
        print('==> calculate train features')
    feats, labels = _extract_features(model, train_loader)
    labels[labels == 0] = -1
    return feats, labels


def _evaluate_svm(train_feats, train_labels, test_feats, test_labels, num_class, cost):
    """
    Fit one LinearSVC per class with regularisation ``cost`` and score it on the test set.

    Returns:
        (mean AP over classes in percent, top-1 accuracy as a fraction in [0, 1]).
        Accuracy compares the argmax of the per-class decision values with a
        single "true" class per test row: the highest-index positive class,
        or 0 when the row has no positive class.
    """
    cls_ap = np.zeros(num_class)
    preds = []
    trues = np.zeros(len(test_labels), dtype=np.int64)
    for cls in range(num_class):
        clf = LinearSVC(
            C=cost, class_weight={1: 2, -1: 1}, intercept_scaling=1.0,
            penalty='l2', loss='squared_hinge', tol=1e-4,
            dual=True, max_iter=2000, random_state=0)
        clf.fit(train_feats, train_labels[:, cls])

        prediction = clf.decision_function(test_feats)
        _, _, _, ap = get_precision_recall(test_labels[:, cls], prediction)
        # ap may be a 1-element array; store it as a plain float.
        cls_ap[cls] = np.asarray(ap).item() * 100
        trues[test_labels[:, cls] > 0] = cls
        preds.append(prediction)
    mean_ap = float(np.mean(cls_ap))
    acc = float(np.mean(np.argmax(np.stack(preds), 0) == trues))
    return mean_ap, acc


@torch.no_grad()
def main():
    """
    Evaluate frozen backbone features with one-vs-rest linear SVMs.

    Test-set features are extracted once. Then, for every shot setting in
    ``k_list`` and every C value in ``--cost``, one LinearSVC per class is
    trained (``n_run`` times) and mAP and accuracy are reported. For each
    setting, the best mean over the C values is printed at the end.
    """
    args = parse_option()

    random.seed(args.seed)
    np.random.seed(args.seed)

    # Checkpoints whose path mentions 'simclr' get identity normalisation
    # (presumably they were trained without ImageNet mean/std — confirm).
    is_simclr = 'simclr' in args.pretrained
    mean = [0, 0, 0] if is_simclr else [0.485, 0.456, 0.406]
    std = [1, 1, 1] if is_simclr else [0.229, 0.224, 0.225]
    normalize = transforms.Normalize(mean=mean, std=std)
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])
    if args.dataset == 'voc2007':
        train_dataset = Voc2007Classification(args.data, set='trainval', transform=transform)
        val_dataset = Voc2007Classification(args.data, set='test', transform=transform)
    else:
        val_dataset = ImageFolderLowshot(args.data, dataname=args.dataset, set='val', transform=transform)
        train_dataset = ImageFolderLowshot(args.data, dataname=args.dataset, set='train', transform=transform)

    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=True)

    model = _build_model(args)
    model.cuda()
    model.eval()

    if VERBOSE:
        print('==> calculate test features')
    test_feats, test_labels = _extract_features(model, val_loader)
    # Use the same 0 -> -1 label convention as the training labels.
    test_labels[test_labels == 0] = -1

    result = {}
    result_acc = {}

    if args.low_shot:
        # Training samples per class; extend (e.g. [1, 2, 4, 8, 16]) for a
        # full low-shot sweep.
        k_list = [4]
    else:
        k_list = ['full']

    cost_list = [float(c) for c in args.cost.split(',')]
    # Cached (features, labels) of the training set. Only reused in full-data
    # mode, where the training set never changes between costs or runs.
    train_data = None
    for k in k_list:
        result_k = np.zeros(len(cost_list))
        result_k_acc = np.zeros(len(cost_list))
        for i, cost in enumerate(cost_list):
            avg_map = []
            avg_acc = []
            for run in range(args.n_run):
                if args.low_shot:
                    # Sample fresh k-shot training data; features must be recomputed.
                    if VERBOSE:
                        print('==> re-sampling training data')
                    train_dataset.convert_low_shot(k)
                    train_data = None
                if train_data is None:
                    train_data = _compute_train_features(model, train_dataset, args)
                train_feats, train_labels = train_data

                if VERBOSE:
                    print('==> training SVM Classifier')
                mean_ap, acc = _evaluate_svm(
                    train_feats, train_labels, test_feats, test_labels,
                    args.num_class, cost)
                avg_map.append(mean_ap)
                avg_acc.append(acc * 100)
                if VERBOSE:
                    print('==> Run%d mAP is %.2f: ' % (run, mean_ap))
                    print('==> Run%d Acc is %.2f: ' % (run, acc * 100))

            avg_map = np.asarray(avg_map)
            if VERBOSE:
                print('Cost:%.2f - Average ap is: %.2f' % (cost, avg_map.mean()))
                print('Cost:%.2f - Std is: %.2f' % (cost, avg_map.std()))
            result_k[i] = avg_map.mean()
            result_k_acc[i] = np.mean(avg_acc)
        # Report the best C value for this shot setting.
        result[k] = result_k.max()
        result_acc[k] = result_k_acc.max()
    print("mAP:")
    print({k: np.round(v, 1) for k, v in result.items()})
    print("=" * 60)
    print("Accuracy:")
    print({k: np.round(v, 1) for k, v in result_acc.items()})

# Run the evaluation only when executed as a script, not when imported.
if __name__ == '__main__':
    main()

